import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from openfunctions_utils import strip_function_calls, parse_function_call

def get_prompt(user_query: str, functions: "list | None" = None) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list | None): Function descriptions to include in the prompt.
      They are serialized with `json.dumps`. `None` or an empty list produces
      the question-only prompt.

    Returns:
    - str: The formatted conversation prompt.
    """
    system = "You are an AI programming assistant, utilizing the Gorilla LLM model, developed by Gorilla LLM, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
    # `None` is the default rather than `[]` so no mutable object is shared
    # between calls; the truthiness test covers both `None` and an empty list.
    if not functions:
        return f"{system}\n### Instruction: <<question>> {user_query}\n### Response: "
    functions_string = json.dumps(functions)
    # NOTE: the two templates differ in the space after `<<question>>`; both
    # are kept exactly as they were so the prompt text sent to the model is
    # unchanged.
    return f"{system}\n### Instruction: <<function>>{functions_string}\n<<question>>{user_query}\n### Response: "


def format_response(response: str) -> tuple:
    """
    Formats the response from the OpenFunctions model.

    Parameters:
    - response (str): The response generated by the LLM.

    Returns:
    - str: The extracted function call string (several calls are joined with
      ", "), or the original `response` unchanged when extraction or parsing
      fails.
    - dict | list | None: The result of `parse_function_call` for a single
      call, a list of such results for parallel calls, or None when
      extraction or parsing fails.

    """
    try:
        # Keep the extracted calls in their own variable so the original
        # `response` string is still available if anything below fails.
        call_strings = strip_function_calls(response)
        if len(call_strings) > 1:
            # Parallel function calls returned as a str, list[dict]
            parsed_calls = [parse_function_call(call) for call in call_strings]
            formatted = ", ".join(call_strings)
        else:
            # Single function call returned as a str, dict.
            # An empty result raises IndexError here and takes the fallback.
            parsed_calls = parse_function_call(call_strings[0])
            formatted = call_strings[0]
    except Exception:
        # Deliberate best effort: the helpers' failure modes are not known
        # here, so on any error just faithfully return the generated response
        # str to the user, with no parsed call.
        return response, None
    return formatted, parsed_calls

# Hardware selection: use the first CUDA device with half precision when CUDA
# is available; otherwise use the CPU with full precision.
device: str
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Identifier handed to `from_pretrained` for both the tokenizer and the model.
model_id: str = "gorilla-llm/gorilla-openfunctions-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
)

# Put the model on the selected device before it is wrapped in a pipeline.
model.to(device)

# Text-generation pipeline built from the model and tokenizer loaded above.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    torch_dtype=torch_dtype,
    batch_size=16,
    max_new_tokens=128,
)

# Example usage 1
#  The query names two cities, so this should return 2 function calls, each
#  with the right argument.
query_1: str = "What's the weather like in the two cities of Boston and San Francisco?"
functions_1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["location"],
        },
    },
]

# Example usage 2
#  This should return an error, since the only function offered (a boiling
#  point calculation) can't help with a question about a freezing point.
query_2: str = "What is the freezing point of water at a pressure of 10 kPa?"
functions_2 = [
    {
        "name": "thermodynamics.calculate_boiling_point",
        "description": "Calculate the boiling point of a given substance at a specific pressure.",
        "parameters": {
            "type": "object",
            "properties": {
                "substance": {
                    "type": "string",
                    "description": "The substance for which to calculate the boiling point.",
                },
                "pressure": {
                    "type": "number",
                    "description": "The pressure at which to calculate the boiling point.",
                },
                "unit": {
                    "type": "string",
                    "description": "The unit of the pressure. Default is 'kPa'.",
                },
            },
            "required": ["substance", "pressure"],
        },
    },
]

# Run example 1 end to end: build the prompt, generate with the pipeline, then
# extract the function call(s) from the generated text.
prompt_1 = get_prompt(query_1, functions=functions_1)
output_1 = pipe(prompt_1)
# Only the `generated_text` field of the first pipeline result is used.
fn_call_string, function_call_dict = format_response(output_1[0]["generated_text"])

# Report the call string(s) and the parsed call, framed by separator lines.
print(
    "--------------------",
    f"Function call strings 1(s): {fn_call_string}",
    "--------------------",
    f"OpenAI compatible `function_call`: {function_call_dict}",
    "--------------------",
    sep="\n",
)
